from typing import Iterable, Optional
from torch import nn

import torch

class vanilla_extractor(nn.Module):
    """Project feature vectors down to a lower-dimensional space.

    Every row of ``input_dim`` features is passed through
    ``Linear -> Dropout -> ReLU -> BatchNorm1d`` and comes out as a row of
    ``low_dim`` features.

    Args:
        input_dim: number of features in each input row.
        low_dim: number of features in each projected row.
        drop_out: dropout probability applied right after the linear layer.
    """

    def __init__(self, input_dim, low_dim, drop_out):
        super().__init__()
        self.input_dim = input_dim
        self.low_dim = low_dim
        self.drop_out = drop_out
        # The position of each stage determines its state_dict key
        # (dim_reduction_layer.0 ... dim_reduction_layer.3); keep the order
        # stable so previously saved checkpoints still load.
        stages = [
            nn.Linear(input_dim, low_dim),
            nn.Dropout(drop_out),
            nn.ReLU(),
            nn.BatchNorm1d(low_dim),
        ]
        self.dim_reduction_layer = nn.Sequential(*stages)

    def forward(self, x_ori):
        """Apply the projection row-wise and regroup the result per sample.

        Args:
            x_ori: tensor whose first dimension is the batch size and whose
                element count is a multiple of ``input_dim`` (presumably
                shaped ``(batch, ..., input_dim)`` -- TODO confirm with
                callers).

        Returns:
            Tensor of shape ``(batch, rows_per_sample, low_dim)``. A 2-D
            ``(batch, input_dim)`` input therefore yields
            ``(batch, 1, low_dim)``.
        """
        n_batch = x_ori.size(0)
        # Flatten to a 2-D matrix so Linear/BatchNorm1d see one row per vector;
        # BatchNorm statistics are thus computed over all rows of all samples.
        flat_rows = x_ori.reshape(-1, self.input_dim)
        projected = self.dim_reduction_layer(flat_rows)
        return projected.reshape(n_batch, -1, self.low_dim)